Add Numba implementation of Blockwise #1015

Open
wants to merge 2 commits into
base: main

Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Oct 4, 2024

Description

Implement a Numba dispatch for Blockwise, for Ops with up to 3 outputs (the limit exists because Numba does not accept tuple generators in the inner functions).

It uses the machinery developed for RVs and Elemwise. The hard part is handling a variable number of inputs and working around Numba's fussiness.

It also improves Blockwise shape inference by using the infer_shape of the core Ops.

The small Cholesky benchmark test I added here runs 10x faster after this PR on my local machine.
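
For readers without context, the semantics being compiled can be sketched in plain Python. This is only an illustration under stated assumptions, not the Numba code from this PR: it handles a single leading batch dimension stored as nested lists, and the names `blockwise_1d` and `min_max` are hypothetical. It shows a core function with several outputs being called once per batch entry, with length-1 batch inputs broadcast.

```python
def blockwise_1d(core_fn, *batched_inputs):
    """Apply core_fn along one leading batch dimension.

    Each input is a list of core values. An input of length 1 is
    broadcast against the others. Returns a tuple holding one batched
    list per core output.
    """
    batch_size = max(len(x) for x in batched_inputs)
    for x in batched_inputs:
        if len(x) not in (1, batch_size):
            raise ValueError("batch dimensions do not broadcast")

    outputs = None
    for i in range(batch_size):
        # Pick the i-th core value, or the only one if the input is broadcast.
        core_inputs = [x[i if len(x) > 1 else 0] for x in batched_inputs]
        result = core_fn(*core_inputs)
        if not isinstance(result, tuple):
            result = (result,)
        if outputs is None:
            # One independent list per core output.
            outputs = tuple([] for _ in result)
        for out, value in zip(outputs, result):
            out.append(value)
    return outputs


def min_max(vector):
    # Hypothetical core op with two outputs.
    return min(vector), max(vector)


print(blockwise_1d(min_max, [[1, 2], [5, 3]]))  # ([1, 3], [2, 5])
```

The number of outputs is only discovered at run time here, which is exactly what a compiled Numba function cannot do freely.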

Related Issue

Comment on lines +62 to +74
    if nout == 1:
        tuple_core_shapes = (to_fixed_tuple(core_shapes[0], core_shape_0),)
    elif nout == 2:
        tuple_core_shapes = (
            to_fixed_tuple(core_shapes[0], core_shape_0),
            to_fixed_tuple(core_shapes[1], core_shape_1),
        )
    else:
        tuple_core_shapes = (
            to_fixed_tuple(core_shapes[0], core_shape_0),
            to_fixed_tuple(core_shapes[1], core_shape_1),
            to_fixed_tuple(core_shapes[2], core_shape_2),
        )
Member Author

If anybody has an idea on how to do this dynamically, that would be great. Do we have to resort to string generation 😭?

Contributor

Are you opposed to cheesing it?

tuple(to_fixed_tuple(core_shapes[i], core_shape_lens[i]) for i in range(nout))

(I don't have full context)

Member Author

@ricardoV94 ricardoV94 Oct 13, 2024

Numba doesn't support that in this context.

Contributor

Ewww. Maybe you could try a bunch of eval statements?

Member Author

We would need to go down the string generation route, as we do for some other Ops (like Scan). But I didn't want to :)
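
The string generation route mentioned in this thread might look roughly like the sketch below. It is an illustration only, not code from this PR: `make_tuple_core_shapes_fn` and `to_fixed_tuple_stub` are hypothetical names, and the stub is a pure-Python stand-in for the `to_fixed_tuple` helper so that the example runs without Numba. Whether Numba would accept the generated function is not verified here. The idea is to write out one explicit tuple entry per output, with each length embedded as a literal, and compile the resulting source with `exec`.

```python
def to_fixed_tuple_stub(array, length):
    # Hypothetical stand-in: take the first `length` items as a tuple.
    return tuple(array[i] for i in range(length))


def make_tuple_core_shapes_fn(core_shape_lens):
    """Generate a function that builds one fixed-length tuple per output.

    core_shape_lens holds the (compile-time) length of each core shape.
    The generated source contains no loops or generators, only a
    literal tuple with one explicit entry per output.
    """
    entries = "".join(
        f"        to_fixed_tuple(core_shapes[{i}], {length}),\n"
        for i, length in enumerate(core_shape_lens)
    )
    src = (
        "def tuple_core_shapes_fn(core_shapes):\n"
        "    return (\n"
        f"{entries}"
        "    )\n"
    )
    # The generated code looks up to_fixed_tuple in this namespace.
    namespace = {"to_fixed_tuple": to_fixed_tuple_stub}
    exec(src, namespace)
    return namespace["tuple_core_shapes_fn"]


fn = make_tuple_core_shapes_fn([2, 1, 3])
print(fn([[4, 5], [6], [7, 8, 9]]))  # ((4, 5), (6,), (7, 8, 9))
```

This removes the hard-coded limit on the number of outputs, at the cost of code that is harder to read and debug than the explicit branches above.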

@ricardoV94 ricardoV94 marked this pull request as ready for review October 6, 2024 13:44
This can only be done when the output of infer_shape of the core_op depends only on the input shapes, and not their values.
Restricted to 3 outputs, due to limitations in jitting of Numba functions
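
The shape inference described in the commit message above can be sketched as follows. This is a simplified illustration, not the PyTensor implementation: it works on concrete integer shapes rather than symbolic ones, and `broadcast_batch_shapes`, `blockwise_infer_shape`, and `matmul_core_infer_shape` are hypothetical names. It assumes the core shape function needs only the core input shapes, which is the condition stated in the commit message.

```python
def broadcast_batch_shapes(*batch_shapes):
    """Broadcast batch shapes NumPy-style (right-aligned, 1s stretch)."""
    ndim = max(len(s) for s in batch_shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in batch_shapes]
    result = []
    for dims in zip(*padded):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"cannot broadcast batch dimensions {dims}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


def blockwise_infer_shape(core_infer_shape, input_shapes, core_ndims):
    """Output shapes = broadcast batch shape + core output shape.

    core_infer_shape maps a list of core input shapes to a list of
    core output shapes, using shapes only and never values.
    """
    batch_shapes = [tuple(s[: len(s) - n]) for s, n in zip(input_shapes, core_ndims)]
    core_shapes = [tuple(s[len(s) - n :]) for s, n in zip(input_shapes, core_ndims)]
    batch_shape = broadcast_batch_shapes(*batch_shapes)
    return [batch_shape + tuple(core) for core in core_infer_shape(core_shapes)]


def matmul_core_infer_shape(core_shapes):
    # Core signature (m,k),(k,n)->(m,n)
    (m, _), (_, n) = core_shapes
    return [(m, n)]


print(blockwise_infer_shape(matmul_core_infer_shape, [(5, 1, 2, 3), (4, 3, 6)], [2, 2]))
# [(5, 4, 2, 6)]
```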

codecov bot commented Oct 7, 2024

Codecov Report

Attention: Patch coverage is 87.64045% with 11 lines in your changes missing coverage. Please review.

Project coverage is 81.74%. Comparing base (fa0ab9d) to head (31cc1e9).
Report is 6 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/link/numba/dispatch/blockwise.py | 76.19% | 9 Missing and 1 partial ⚠️ |
| pytensor/link/numba/dispatch/random.py | 0.00% | 1 Missing ⚠️ |

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1015      +/-   ##
==========================================
- Coverage   81.75%   81.74%   -0.01%     
==========================================
  Files         183      185       +2     
  Lines       47756    47816      +60     
  Branches    11620    11632      +12     
==========================================
+ Hits        39044    39089      +45     
- Misses       6519     6529      +10     
- Partials     2193     2198       +5     

| Files with missing lines | Coverage Δ |
|---|---|
| pytensor/link/numba/dispatch/__init__.py | 100.00% <100.00%> (ø) |
| pytensor/tensor/blockwise.py | 84.68% <100.00%> (+2.82%) ⬆️ |
| pytensor/tensor/rewriting/numba.py | 100.00% <100.00%> (ø) |
| pytensor/link/numba/dispatch/random.py | 58.97% <0.00%> (ø) |
| pytensor/link/numba/dispatch/blockwise.py | 76.19% <76.19%> (ø) |

... and 7 files with indirect coverage changes

Development

Successfully merging this pull request may close these issues.

Vectorize follow-up
2 participants